import torch

def test1():
    """Demonstrate matrix multiplication with the ``@`` operator.

    A (3, 2) integer tensor times a (2, 2) integer tensor yields a (3, 2)
    result, which is printed.
    """
    left = torch.tensor([[1, 2], [3, 4], [5, 6]])
    right = torch.tensor([[5, 6], [7, 8]])
    print(left @ right)

def test2():
    """Demonstrate ``torch.mm``, which only accepts 2-D tensors.

    ``torch.mm`` is a plain matrix product with no broadcasting. Both operands
    must therefore be matrices whose inner dimensions agree. Here the shapes
    are (3, 2) x (2, 2) -> (3, 2).
    """
    lhs = torch.tensor([[1, 2], [3, 4], [5, 6]])
    rhs = torch.tensor([[5, 6], [7, 8]])
    product = lhs.mm(rhs)
    print(product)

def test3():
    """Demonstrate batched matrix multiplication with ``torch.bmm``.

    ``bmm`` needs two 3-D operands shaped (b, n, m) and (b, m, p) with the
    same batch size b, and it returns a (b, n, p) tensor. The inputs are
    random, so only the result's shape is printed.
    """
    batch_a = torch.randn(3, 4, 5)
    batch_b = torch.randn(3, 5, 8)
    print(batch_a.bmm(batch_b).size())

def test4():
    """Demonstrate ``torch.matmul`` on both 2-D and 3-D (batched) inputs.

    The first pair of operands is two matrices, giving an ordinary matrix
    product. The second pair is 3-D, and the product is taken batch-wise.
    Each result is printed.
    """
    shape_pairs = (
        ((4, 5), (5, 8)),        # matrix x matrix -> (4, 8)
        ((3, 4, 5), (3, 5, 8)),  # batched         -> (3, 4, 8)
    )
    for lhs_shape, rhs_shape in shape_pairs:
        # Sample the left operand before the right one so the random draws
        # happen in the same order as sampling each pair in sequence.
        lhs = torch.randn(*lhs_shape)
        rhs = torch.randn(*rhs_shape)
        print(torch.matmul(lhs, rhs))


if __name__ == '__main__':
    # Run every matrix-multiplication demo in order; each prints its own output.
    for demo in (test1, test2, test3, test4):
        demo()